{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "93075c79-81cf-42ec-9931-12443efaaf4b",
   "metadata": {},
   "source": [
    "# Lab 5 - Embedding Adaptors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a5536f0-651c-40e7-aa15-27ee0cda80b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from helper_utils import load_chroma, word_wrap, project_embeddings\n",
    "from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction\n",
    "import numpy as np\n",
    "import umap\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3748b16d-d4a7-49c3-a48a-57dcfc42acd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_function = SentenceTransformerEmbeddingFunction()\n",
    "\n",
    "chroma_collection = load_chroma(filename='microsoft_annual_report_2022.pdf', collection_name='microsoft_annual_report_2022', embedding_function=embedding_function)\n",
    "chroma_collection.count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a338ec83-6301-41a5-9ab1-e5d583306a3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = chroma_collection.get(include=['embeddings'])['embeddings']\n",
    "umap_transform = umap.UMAP(random_state=0, transform_seed=0).fit(embeddings)\n",
    "projected_dataset_embeddings = project_embeddings(embeddings, umap_transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5665c695-22ea-4264-b1ac-5ba720b6d78b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    "from openai import OpenAI\n",
    "\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "_ = load_dotenv(find_dotenv()) # read local .env file\n",
    "openai.api_key = os.environ['OPENAI_API_KEY']\n",
    "\n",
    "openai_client = OpenAI()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a34ff415-8d20-4171-9d40-7eb1bbe837a1",
   "metadata": {},
   "source": [
    "## Creating a dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ba6c8c5-9ce4-44d0-9223-6fdd77871f87",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_queries(model=\"gpt-3.5-turbo\"):\n",
    "    messages = [\n",
    "        {\n",
    "            \"role\": \"system\",\n",
    "            \"content\": \"You are a helpful expert financial research assistant. You help users analyze financial statements to better understand companies. \"\n",
    "            \"Suggest 10 to 15 short questions that are important to ask when analyzing an annual report. \"\n",
    "            \"Do not output any compound questions (questions with multiple sentences or conjunctions).\"\n",
    "            \"Output each question on a separate line divided by a newline.\"\n",
    "        },\n",
    "    ]\n",
    "\n",
    "    response = openai_client.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=messages,\n",
    "    )\n",
    "    content = response.choices[0].message.content\n",
    "    content = content.split(\"\\n\")\n",
    "    return content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfdb54db-a442-423c-b006-c33a257cd7d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_queries = generate_queries()\n",
    "for query in generated_queries:\n",
    "    print(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "377a84aa-1d93-4e97-9b2d-d59c46355338",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = chroma_collection.query(query_texts=generated_queries, n_results=10, include=['documents', 'embeddings'])\n",
    "retrieved_documents = results['documents']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba0ed8ca-6640-4c09-9cb3-9de5e7cf46dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_results(query, statement, model=\"gpt-3.5-turbo\"):\n",
    "    messages = [\n",
    "    {\n",
    "        \"role\": \"system\",\n",
    "        \"content\": \"You are a helpful expert financial research assistant. You help users analyze financial statements to better understand companies. \"\n",
    "        \"For the given query, evaluate whether the following satement is relevant.\"\n",
    "        \"Output only 'yes' or 'no'.\"\n",
    "    },\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": f\"Query: {query}, Statement: {statement}\"\n",
    "    }\n",
    "    ]\n",
    "\n",
    "    response = openai_client.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=messages,\n",
    "        max_tokens=1\n",
    "    )\n",
    "    content = response.choices[0].message.content\n",
    "    if content == \"yes\":\n",
    "        return 1\n",
    "    return -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28bac3a2-0d29-48dc-9b48-2d9313239a25",
   "metadata": {},
   "outputs": [],
   "source": [
    "retrieved_embeddings = results['embeddings']\n",
    "query_embeddings = embedding_function(generated_queries)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db9f2758-0f5a-49e5-b1fa-517b91324575",
   "metadata": {},
   "outputs": [],
   "source": [
    "adapter_query_embeddings = []\n",
    "adapter_doc_embeddings = []\n",
    "adapter_labels = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aee59493-8a99-4da8-b94f-4747efcfc79d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for q, query in enumerate(tqdm(generated_queries)):\n",
    "    for d, document in enumerate(retrieved_documents[q]):\n",
    "        adapter_query_embeddings.append(query_embeddings[q])\n",
    "        adapter_doc_embeddings.append(retrieved_embeddings[q][d])\n",
    "        adapter_labels.append(evaluate_results(query, document))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c65337e9-85ee-47f7-89fd-7fe77cd0e1b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(adapter_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "babe7893-9cbc-43c5-94ef-cbf8f5d68cf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "adapter_query_embeddings = torch.Tensor(np.array(adapter_query_embeddings))\n",
    "adapter_doc_embeddings = torch.Tensor(np.array(adapter_doc_embeddings))\n",
    "adapter_labels = torch.Tensor(np.expand_dims(np.array(adapter_labels),1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60a9524b-1085-4bdf-a161-39f11397dc1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = torch.utils.data.TensorDataset(adapter_query_embeddings, adapter_doc_embeddings, adapter_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdac56f2-b98c-46ab-9cc4-248b79177ef8",
   "metadata": {},
   "source": [
    "## Setting up the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b26a01a-4575-446b-b8dc-a8c5ab153172",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(query_embedding, document_embedding, adaptor_matrix):\n",
    "    updated_query_embedding = torch.matmul(adaptor_matrix, query_embedding)\n",
    "    return torch.cosine_similarity(updated_query_embedding, document_embedding, dim=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0950575b-b69d-46a3-8c91-c7af89f5c204",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mse_loss(query_embedding, document_embedding, adaptor_matrix, label):\n",
    "    return torch.nn.MSELoss()(model(query_embedding, document_embedding, adaptor_matrix), label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f123ad8-b2e8-4a25-8b42-a520ecaf566b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the adaptor matrix\n",
    "mat_size = len(adapter_query_embeddings[0])\n",
    "adapter_matrix = torch.randn(mat_size, mat_size, requires_grad=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83c04587-d1de-419c-a213-2e3eb67dc33d",
   "metadata": {},
   "outputs": [],
   "source": [
    "min_loss = float('inf')\n",
    "best_matrix = None\n",
    "\n",
    "for epoch in tqdm(range(100)):\n",
    "    for query_embedding, document_embedding, label in dataset:\n",
    "        loss = mse_loss(query_embedding, document_embedding, adapter_matrix, label)\n",
    "\n",
    "        if loss < min_loss:\n",
    "            min_loss = loss\n",
    "            best_matrix = adapter_matrix.clone().detach().numpy()\n",
    "\n",
    "        loss.backward()\n",
    "        with torch.no_grad():\n",
    "            adapter_matrix -= 0.01 * adapter_matrix.grad\n",
    "            adapter_matrix.grad.zero_()\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3155972-824e-4ebe-a692-2227c113c5a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Best loss: {min_loss.detach().numpy()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8144a4a-85f6-4800-87f9-36a1b6ceda1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_vector = torch.ones((mat_size,1))\n",
    "scaled_vector = np.matmul(best_matrix, test_vector).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ff0b18e-12a0-4ac0-97dd-8618b22e7dbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.bar(range(len(scaled_vector)), scaled_vector.flatten())\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03ca7e7c-4b47-4652-9b46-a40b3dffa5e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_embeddings = embedding_function(generated_queries)\n",
    "adapted_query_embeddings = np.matmul(best_matrix, np.array(query_embeddings).T).T\n",
    "\n",
    "projected_query_embeddings = project_embeddings(query_embeddings, umap_transform)\n",
    "projected_adapted_query_embeddings = project_embeddings(adapted_query_embeddings, umap_transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f74e7d67-7f51-41c4-8e25-edbaa02d0bd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the projected query and retrieved documents in the embedding space\n",
    "plt.figure()\n",
    "plt.scatter(projected_dataset_embeddings[:, 0], projected_dataset_embeddings[:, 1], s=10, color='gray')\n",
    "plt.scatter(projected_query_embeddings[:, 0], projected_query_embeddings[:, 1], s=150, marker='X', color='r', label=\"original\")\n",
    "plt.scatter(projected_adapted_query_embeddings[:, 0], projected_adapted_query_embeddings[:, 1], s=150, marker='X', color='green', label=\"adapted\")\n",
    "\n",
    "plt.gca().set_aspect('equal', 'datalim')\n",
    "plt.title(\"Adapted Queries\")\n",
    "plt.axis('off')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9188e886-d406-406f-b234-f5c3353a77a2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d3bb286-2694-4ed4-8466-46865e997ced",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2876084b-4038-4b0c-8ec8-8294a86adfc1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ac542e1-b094-431f-9611-cf7e36d3f0de",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcd6114b-c09d-4173-a623-9a08aaf63e4b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad10ab65-b351-4f4b-b7d2-63474acfb9f9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "800f3d81-cbdb-4ba4-8d49-85747fdfded8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37847448-c9f6-4f51-bf06-f7809964a8b2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dcefc87-0964-4b94-946b-2145781ad606",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fc994bc-7b1e-476a-9df9-300a3e374882",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ef5f5d5-acb7-4b0a-93ef-e61306708e69",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44e4b33f-d8fb-4f3a-b884-8b43a3766583",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2a480a2-2c29-4a01-80dd-ee41934b7901",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8127c2bf-0d15-4b62-b46a-f7a17ad2ec92",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18ded129-a637-4269-a116-550fe9a90570",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1d7ee44-7b29-483f-a3f2-cc9d8e18880e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e450dd8-9719-42c6-8c3c-33cac910e0a5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
